{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tlvY59v6PjvC"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "jiFUWa49Pl1F"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XjJkne25Prps"
      },
      "source": [
        "# Fine-tuning Gemma for Function Calling\n",
        "\n",
        "Welcome to this step-by-step guide on fine-tuning the [Gemma](https://huggingface.co/google/gemma-2b) for Function Calling.\n",
        "\n",
        "\n",
        "[**Gemma**](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop, or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.\n",
        "\n",
        "**Function calling finetuning** is a crucial step in enhancing the performance of LLMs with function calling capabilities. It involves training the model on a dataset of prompts and corresponding function calls, enabling it to accurately identify the appropriate function for a given task. By fine-tuning the model, it learns to better understand the nuances of natural language, recognize the intent behind prompts, and select the most suitable functions.\n",
        "\n",
        "This notebook uses [Torch XLA](https://github.com/pytorch/xla) and Hugging Face's [**Transformer Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) framework for Function calling finetuning.\n",
        "\n",
        "[**Torch XLA**](https://pytorch.org/xla/) enables you to leverage the computational power of TPUs (Tensor Processing Units) for efficient training of deep learning models. By interfacing PyTorch with the [XLA (Accelerated Linear Algebra)](https://openxla.org/xla) compiler, Torch XLA translates PyTorch operations into XLA operations that can be executed on TPUs. This means you can write your models in PyTorch as usual, and Torch XLA handles the underlying computations to run them efficiently on TPUs.\n",
        "\n",
        "**Transformer Reinforcement Learning (TRL)**](https://github.com/huggingface/trl) is a framework developed by Hugging Face to fine-tune and align both transformer language and diffusion models using methods such as Supervised Fine-Tuning (SFT), Reward Modeling (RM), Proximal Policy Optimization (PPO), Direct Preference Optimization (DPO), and others.\n",
        "\n",
        "To know more about how to use Torch XLA and TRL to finetune Gemma, check the **Finetune with Torch XLA** notebook from [Gemma Cookbook](https://github.com/google-gemini/gemma-cookbook/blob/main/Gemma/Finetune_with_Torch_XLA.ipynb).\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_Function_Calling.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>\n",
        "<br><br>\n",
        "\n",
        "[![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](\"https://www.kaggle.com/notebooks/welcome?src=https://github.com/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Finetune_with_Function_Calling.ipynb\")"
      ]
    },
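    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the idea of prompts paired with function calls concrete, here is a minimal sketch of what one such pair could look like and how an application might execute the call a model emits. The `get_weather` function, the `TOOLS` registry, the `dispatch` helper, and the JSON layout are all hypothetical names and assumptions chosen for illustration; they may differ from the dataset format used in this notebook.\n",
        "\n",
        "```python\n",
        "import json\n",
        "\n",
        "\n",
        "# Hypothetical tool: a stand-in for a real function the model could call.\n",
        "def get_weather(city: str, unit: str = 'celsius') -> str:\n",
        "    return f'Weather in {city}: 21 degrees {unit}'\n",
        "\n",
        "\n",
        "# Registry mapping function names to Python callables (illustrative only).\n",
        "TOOLS = {'get_weather': get_weather}\n",
        "\n",
        "# One assumed training example: a user prompt paired with the target call.\n",
        "example = {\n",
        "    'prompt': 'What is the weather like in Paris?',\n",
        "    'function_call': json.dumps(\n",
        "        {'name': 'get_weather', 'arguments': {'city': 'Paris'}}\n",
        "    ),\n",
        "}\n",
        "\n",
        "\n",
        "def dispatch(function_call: str) -> str:\n",
        "    '''Parse a JSON function call and run the matching registered tool.'''\n",
        "    call = json.loads(function_call)\n",
        "    name = call['name']\n",
        "    if name not in TOOLS:\n",
        "        raise KeyError(f'Unknown function: {name}')\n",
        "    return TOOLS[name](**call.get('arguments', {}))\n",
        "\n",
        "\n",
        "print(dispatch(example['function_call']))\n",
        "```\n",
        "\n",
        "Fine-tuning aims to make the model produce a string like `function_call` reliably for prompts like this one; parsing and executing the call remains the application's job."
      ]
    },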
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QdM3rXG4U-mb"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Selecting the Runtime Environment\n",
        "\n",
        "To start, you can choose either **Google Colab** or **Kaggle** as your platform. Select one, and proceed from there.\n",
        "\n",
        "- #### **Google Colab** <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png\" alt=\"Google Colab\" width=\"30\"/>\n",
        "\n",
        "  1. Click **Open in Colab**.\n",
        "  2. In the menu, go to **Runtime** > **Change runtime type**.\n",
        "  3. Under **Hardware accelerator**, select **TPU**.\n",
        "  4. Ensure that the **TPU type** is set to **TPU v2-8**.\n",
        "\n",
        "- #### **Kaggle** <img src=\"https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png\" alt=\"Kaggle\" width=\"40\"/>\n",
        "\n",
        "  1. Click **Open in Kaggle**.\n",
        "  2. Click on **Settings** in the right sidebar.\n",
        "  3. Under **Accelerator**, select **TPUs**.\n",
        "    - Note: Kaggle currently provides **TPU v3-8**.\n",
        "  4. Save the settings, and the notebook will restart with TPU support.\n",
        "\n",
        "\n",
        "### Gemma using Hugging Face\n",
        "\n",
        "Before diving into the tutorial, let's set up Gemma:\n",
        "\n",
        "1. **Create a Hugging Face Account**: If you don't have one, you can sign up for a free account [here](https://huggingface.com/join).\n",
        "2. **Access the Gemma Model**: Visit the [Gemma model page](https://huggingface.com/collections/google/gemma-2-release-667d6600fd5220e7b967f315) and accept the usage conditions.\n",
        "3. **Generate a Hugging Face Token**: Go to your Hugging Face [settings page](https://huggingface.com/settings/tokens) and generate a new access token (preferably with `write` permissions). You'll need this token later in the tutorial.\n",
        "\n",
        "**Once you've completed these steps, you're ready to move on to the next section where you'll set up environment variables in your Colab environment.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3pjmx0qYVI5_"
      },
      "source": [
        "### Configure Your Credentials\n",
        "\n",
        "To access private models and datasets, you need to log in to the Hugging Face (HF) ecosystem.\n",
        "\n",
        "- #### **Google Colab** <img src=\"https://upload.wikimedia.org/wikipedia/commons/thumb/d/d0/Google_Colaboratory_SVG_Logo.svg/1200px-Google_Colaboratory_SVG_Logo.svg.png\" alt=\"Google Colab\" width=\"30\"/>\n",
        "  If you're using Colab, you can securely store your Hugging Face token (`HF_TOKEN`) using the Colab Secrets manager:\n",
        "  1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "  2. **Add Hugging Face Token**:\n",
        "    - Create a new secret with the **name** `HF_TOKEN`.\n",
        "    - Copy/paste your token key into the **Value** input box of `HF_TOKEN`.\n",
        "    - **Toggle** the button on the left to allow notebook access to the secret\n",
        "\n",
        "- #### **Kaggle** <img src=\"https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png\" alt=\"Kaggle\" width=\"40\"/>\n",
        "  To securely use your Hugging Face token (`HF_TOKEN`) in this notebook, you'll need to add it as a secret in your Kaggle environment:  \n",
        "  1. Open your Kaggle notebook and locate the **Addons** menu at the top in your notebook interface.\n",
        "  2. Click on **Secrets** to manage your environment secrets.  \n",
        "  <img src=\"https://i.imgur.com/vxrtJuM.png\" alt=\"The Secrets option is found at the top.\" width=50%>\n",
        "  3. **Add Hugging Face Token**:\n",
        "      - Click on the **Add secret** button.\n",
        "      - In the **Label** field, enter `HF_TOKEN`.  \n",
        "      - In the **Value** field, paste your Hugging Face token.\n",
        "      - Click **Save** to add the secret."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V7GXw4tjVS0H"
      },
      "source": [
        "This code retrieves your secrets and sets them as environment variables, which you will use later in the tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j6tFvOkyZz_o"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "\n",
        "if 'google.colab' in sys.modules:\n",
        "    # Running on Colab\n",
        "    from google.colab import userdata\n",
        "    os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')\n",
        "elif os.path.exists('/kaggle/working'):\n",
        "    # Running on Kaggle\n",
        "    from kaggle_secrets import UserSecretsClient\n",
        "    user_secrets = UserSecretsClient()\n",
        "    os.environ['HF_TOKEN'] = user_secrets.get_secret(\"HF_TOKEN\")\n",
        "else:\n",
        "    # Not running on Colab or Kaggle\n",
        "    raise EnvironmentError('This notebook is designed to run on Google Colab or Kaggle.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_zS6gPp1VgPd"
      },
      "source": [
        "### Install dependencies\n",
        "\n",
        "Next, you'll set up the environment by installing all the necessary Python packages for fine-tuning the Gemma model on a TPU VM using Torch XLA.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bVaISQ8GZz_s"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Found existing installation: tensorflow 2.15.0\n",
            "Uninstalling tensorflow-2.15.0:\n",
            "  Successfully uninstalled tensorflow-2.15.0\n",
            "Found existing installation: tf_keras 2.15.1\n",
            "Uninstalling tf_keras-2.15.1:\n",
            "  Successfully uninstalled tf_keras-2.15.1\n",
            "Collecting tensorflow==2.18.0\n",
            "  Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)\n",
            "Collecting tf-keras==2.18.0\n",
            "  Downloading tf_keras-2.18.0-py3-none-any.whl.metadata (1.6 kB)\n",
            "Requirement already satisfied: absl-py>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.4.0)\n",
            "Requirement already satisfied: astunparse>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.6.3)\n",
            "Requirement already satisfied: flatbuffers>=24.3.25 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (24.3.25)\n",
            "Requirement already satisfied: gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (0.6.0)\n",
            "Requirement already satisfied: google-pasta>=0.1.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (0.2.0)\n",
            "Requirement already satisfied: libclang>=13.0.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (18.1.1)\n",
            "Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (3.4.0)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (24.2)\n",
            "Requirement already satisfied: protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<6.0.0dev,>=3.20.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (4.25.5)\n",
            "Requirement already satisfied: requests<3,>=2.21.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (2.32.3)\n",
            "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (75.1.0)\n",
            "Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.16.0)\n",
            "Requirement already satisfied: termcolor>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (2.5.0)\n",
            "Requirement already satisfied: typing-extensions>=3.6.6 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (4.12.2)\n",
            "Requirement already satisfied: wrapt>=1.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.14.1)\n",
            "Requirement already satisfied: grpcio<2.0,>=1.24.3 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.68.1)\n",
            "Collecting tensorboard<2.19,>=2.18 (from tensorflow==2.18.0)\n",
            "  Downloading tensorboard-2.18.0-py3-none-any.whl.metadata (1.6 kB)\n",
            "Collecting keras>=3.5.0 (from tensorflow==2.18.0)\n",
            "  Downloading keras-3.7.0-py3-none-any.whl.metadata (5.8 kB)\n",
            "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (1.26.4)\n",
            "Requirement already satisfied: h5py>=3.11.0 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (3.12.1)\n",
            "Collecting ml-dtypes<0.5.0,>=0.4.0 (from tensorflow==2.18.0)\n",
            "  Downloading ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (20 kB)\n",
            "Requirement already satisfied: tensorflow-io-gcs-filesystem>=0.23.1 in /usr/local/lib/python3.10/dist-packages (from tensorflow==2.18.0) (0.37.1)\n",
            "Requirement already satisfied: wheel<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from astunparse>=1.6.0->tensorflow==2.18.0) (0.45.1)\n",
            "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from keras>=3.5.0->tensorflow==2.18.0) (13.9.4)\n",
            "Collecting namex (from keras>=3.5.0->tensorflow==2.18.0)\n",
            "  Downloading namex-0.0.8-py3-none-any.whl.metadata (246 bytes)\n",
            "Collecting optree (from keras>=3.5.0->tensorflow==2.18.0)\n",
            "  Downloading optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (47 kB)\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m47.8/47.8 kB\u001b[0m \u001b[31m1.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow==2.18.0) (3.4.0)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow==2.18.0) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow==2.18.0) (2.2.3)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.21.0->tensorflow==2.18.0) (2024.8.30)\n",
            "Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.19,>=2.18->tensorflow==2.18.0) (3.7)\n",
            "Requirement already satisfied: tensorboard-data-server<0.8.0,>=0.7.0 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.19,>=2.18->tensorflow==2.18.0) (0.7.2)\n",
            "Requirement already satisfied: werkzeug>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from tensorboard<2.19,>=2.18->tensorflow==2.18.0) (3.1.3)\n",
            "Requirement already satisfied: MarkupSafe>=2.1.1 in /usr/local/lib/python3.10/dist-packages (from werkzeug>=1.0.1->tensorboard<2.19,>=2.18->tensorflow==2.18.0) (3.0.2)\n",
            "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.5.0->tensorflow==2.18.0) (3.0.0)\n",
            "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->keras>=3.5.0->tensorflow==2.18.0) (2.18.0)\n",
            "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->keras>=3.5.0->tensorflow==2.18.0) (0.1.2)\n",
            "Downloading tensorflow-2.18.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (615.3 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m615.3/615.3 MB\u001b[0m \u001b[31m1.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading tf_keras-2.18.0-py3-none-any.whl (1.7 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m27.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading keras-3.7.0-py3-none-any.whl (1.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.2/1.2 MB\u001b[0m \u001b[31m44.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading ml_dtypes-0.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.2 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m55.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading tensorboard-2.18.0-py3-none-any.whl (5.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.5/5.5 MB\u001b[0m \u001b[31m85.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hDownloading namex-0.0.8-py3-none-any.whl (5.8 kB)\n",
            "Downloading optree-0.13.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (381 kB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m381.3/381.3 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: namex, optree, ml-dtypes, tensorboard, keras, tensorflow, tf-keras\n",
            "  Attempting uninstall: ml-dtypes\n",
            "    Found existing installation: ml-dtypes 0.2.0\n",
            "    Uninstalling ml-dtypes-0.2.0:\n",
            "      Successfully uninstalled ml-dtypes-0.2.0\n",
            "  Attempting uninstall: tensorboard\n",
            "    Found existing installation: tensorboard 2.15.2\n",
            "    Uninstalling tensorboard-2.15.2:\n",
            "      Successfully uninstalled tensorboard-2.15.2\n",
            "  Attempting uninstall: keras\n",
            "    Found existing installation: keras 2.15.0\n",
            "    Uninstalling keras-2.15.0:\n",
            "      Successfully uninstalled keras-2.15.0\n",
            "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
            "tensorflow-text 2.15.0 requires tensorflow<2.16,>=2.15.0; platform_machine != \"arm64\" or platform_system != \"Darwin\", but you have tensorflow 2.18.0 which is incompatible.\u001b[0m\u001b[31m\n",
            "\u001b[0mSuccessfully installed keras-3.7.0 ml-dtypes-0.4.1 namex-0.0.8 optree-0.13.1 tensorboard-2.18.0 tensorflow-2.18.0 tf-keras-2.18.0\n",
            "Found existing installation: tensorflow 2.18.0\n",
            "Uninstalling tensorflow-2.18.0:\n",
            "  Successfully uninstalled tensorflow-2.18.0\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m230.0/230.0 MB\u001b[0m \u001b[31m4.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.1/44.1 kB\u001b[0m \u001b[31m1.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.0/10.0 MB\u001b[0m \u001b[31m65.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m69.2/69.2 kB\u001b[0m \u001b[31m1.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m1.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m3.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m28.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m241.9/241.9 kB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.6/124.6 kB\u001b[0m \u001b[31m3.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m205.1/205.1 kB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m319.7/319.7 kB\u001b[0m \u001b[31m12.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m310.2/310.2 kB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m320.7/320.7 kB\u001b[0m \u001b[31m14.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m324.3/324.3 kB\u001b[0m \u001b[31m5.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hRequirement already satisfied: tpu-info in /usr/local/lib/python3.10/dist-packages (0.2.0)\n",
            "Requirement already satisfied: grpcio>=1.65.5 in /usr/local/lib/python3.10/dist-packages (from tpu-info) (1.68.1)\n",
            "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from tpu-info) (4.25.5)\n",
            "Requirement already satisfied: rich in /usr/local/lib/python3.10/dist-packages (from tpu-info) (13.9.4)\n",
            "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich->tpu-info) (3.0.0)\n",
            "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich->tpu-info) (2.18.0)\n",
            "Requirement already satisfied: typing-extensions<5.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from rich->tpu-info) (4.12.2)\n",
            "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich->tpu-info) (0.1.2)\n"
          ]
        }
      ],
      "source": [
        "# Uninstalling any existing TensorFlow installations and then install the CPU-only version to avoid conflicts while using the TPU.\n",
        "!pip uninstall -y tensorflow tf-keras\n",
        "!pip install tensorflow==2.18.0 tf-keras==2.18.0\n",
        "\n",
        "!pip uninstall tensorflow -y\n",
        "!pip install tensorflow-cpu==2.18.0 -q\n",
        "\n",
        "# Install the appropriate Hugging Face libraries to ensure compatibility with the Gemma model and PEFT.\n",
        "!pip install transformers==4.46.1 -U -q\n",
        "!pip install datasets==3.1.0 -U -q\n",
        "!pip install trl==0.12.0 peft==0.13.2 -U -q\n",
        "!pip install accelerate==0.34.0 -U -q\n",
        "\n",
        "# Install PyTorch and Torch XLA with versions compatible with the TPU runtime, ensuring efficient TPU utilization.\n",
        "!pip install -qq torch~=2.5.0 --index-url https://download.pytorch.org/whl/cpu\n",
        "!pip install -qq torch_xla[tpu]~=2.5.0 -f https://storage.googleapis.com/libtpu-releases/index.html\n",
        "\n",
        "# Install the `tpu-info` package to display TPU-related information\n",
        "!pip install tpu-info"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cq1_60_2VoHH"
      },
      "source": [
        "**Note**: Ensure that your PyTorch and Torch XLA versions are compatible with the TPU you're using."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WkAuR_y0VsCH"
      },
      "source": [
        "### Verify TPU Setup\n",
        "\n",
        "You run `!tpu-info` to verify the TPU has been properly initialized."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GVwq7c3mVwyx"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[3mTPU Chips                                     \u001b[0m\n",
            "┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━┓\n",
            "┃\u001b[1m \u001b[0m\u001b[1mChip       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mType       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mDevices\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mPID \u001b[0m\u001b[1m \u001b[0m┃\n",
            "┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━┩\n",
            "│ /dev/accel0 │ TPU v2 chip │ 2       │ None │\n",
            "│ /dev/accel1 │ TPU v2 chip │ 2       │ None │\n",
            "│ /dev/accel2 │ TPU v2 chip │ 2       │ None │\n",
            "│ /dev/accel3 │ TPU v2 chip │ 2       │ None │\n",
            "└─────────────┴─────────────┴─────────┴──────┘\n",
            "Libtpu metrics unavailable. Is there a framework using the TPU? See https://github.com/google/cloud-accelerator-diagnostics/tree/main/tpu_info for more information\n"
          ]
        }
      ],
      "source": [
        "!tpu-info"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wkakCqYSV1wf"
      },
      "source": [
        "If everything is set up correctly, you should see the TPU details printed out."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rbfldn4KWi5B"
      },
      "source": [
        "## Finetuning Gemma 2 for Function Calling"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KVyMsQc7WoIu"
      },
      "source": [
        "### Initializing Gemma 2 model\n",
        "\n",
        "You will initialize the `AutoModelForCausalLM` from the `transformers` library by loading a pre-trained Gemma 2 model from HuggingFace. You will also initialize the tokenizer for the selected model(`google/gemma-2-2b-it`) using the `AutoTokenizer` from the `transformers` library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TFhbJSgsWR00"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "852b25c573a64363b50028519cda8b33",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/838 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "14411043503f4eb0866b6a996d3d5325",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors.index.json:   0%|          | 0.00/24.2k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "06c8910909274b1484045171537b901c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "69fe4ab92fa14cdda2303c6916c57ca4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "0ee50fa5e7c742dc8f388217826d2c02",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model-00002-of-00002.safetensors:   0%|          | 0.00/241M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c5a5e68ed82d4abf8249a5fa89db70c9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "f7e59e5e10c14500b1020d8134000079",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "generation_config.json:   0%|          | 0.00/187 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b0edf78858b8416baf3f65b5023a145b",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/47.0k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4e3e4d814a9c47bca1e582c2046b165c",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "c080631be64d4e588da086810a493abb",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "75485dda100d4adaaa4473d3607b9f1e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "import torch\n",
        "from transformers import (\n",
        "    AutoTokenizer,\n",
        "    AutoModelForCausalLM,\n",
        ")\n",
        "\n",
        "# Define model names\n",
        "model_name = \"google/gemma-2-2b-it\"\n",
        "new_model = \"gemma-func-ft\"\n",
        "\n",
        "# Load the Gemma pre-trained model\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    torch_dtype=torch.bfloat16\n",
        ")\n",
        "\n",
        "# You must disable the cache to prevent issues during training\n",
        "model.config.use_cache = False\n",
        "\n",
        "# Load the Gemma tokenizer\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "\n",
        "# You adjust the tokenizer's padding side to ensure compatibility during TPU\n",
        "# training.\n",
        "tokenizer.padding_side = \"right\" # Fix overflow issue with bf16/fp16 training"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6GkfQfN-XcQz"
      },
      "source": [
        "Enable Single Program Multiple Data (SPMD) mode,\n",
        "which allows for parallel execution across multiple TPU cores.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mTDNX5tzXhw9"
      },
      "outputs": [],
      "source": [
        "import torch_xla\n",
        "import torch_xla.core.xla_model as xm\n",
        "import torch_xla.runtime as xr\n",
        "\n",
        "xr.use_spmd()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hXN5NwPoXxZV"
      },
      "source": [
        "### Load a dataset\n",
        "\n",
        "For this guide, you'll use an existing dataset from Hugging Face. You can replace it with your dataset if you prefer.\n",
        "\n",
        "The dataset chosen for this guide is [**lilacai/glaive-function-calling-v2-sharegpt**](https://huggingface.co/datasets/lilacai/glaive-function-calling-v2-sharegpt), which is a ShareGPT version of the original **glaive-function-calling-v2** dataset by glaiveai. The glaive-function-calling-v2 dataset is a collection of over 113,000 prompts and corresponding function calls that can fine-tune language models to identify the appropriate function for a given task accurately.\n",
        "\n",
        "**Credits:** **https://huggingface.co/lilacai**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "--QwWdczZz_v"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "06848119d6804263858d4fb6de1dec45",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md:   0%|          | 0.00/2.51k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "1b3fa19bfc1d4d08977981bf1b9a88bc",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "(…)-00000-of-00002-6f3344faa23e9b0a.parquet:   0%|          | 0.00/98.0M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2ceebc4a913b45cbb7f45e1135648a11",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "(…)-00001-of-00002-41f063cddf49c933.parquet:   0%|          | 0.00/98.5M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6649343ca260492b8ec5febfdeff4c5f",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split:   0%|          | 0/112960 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "from datasets import Dataset, load_dataset\n",
        "\n",
        "# Only the first 15% of the `train` split is used for training. A smaller\n",
        "# subsection of the dataset is selected to avoid out-of-memory crashes.\n",
        "dataset = load_dataset(\"lilacai/glaive-function-calling-v2-sharegpt\", split=\"train[:15%]\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W-EclVpoaQR6"
      },
      "source": [
        "Let's look at a few samples to understand the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6eXNAjx_Zz_0"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[{'from': 'system',\n",
              "  'value': 'You are a helpful assistant with access to the following functions. Use them if required -\\n{\\n    \"name\": \"calculate_discount\",\\n    \"description\": \"Calculate the discount amount based on original price and discount percentage\",\\n    \"parameters\": {\\n        \"type\": \"object\",\\n        \"properties\": {\\n            \"original_price\": {\\n                \"type\": \"number\",\\n                \"description\": \"The original price of the item\"\\n            },\\n            \"discount_percentage\": {\\n                \"type\": \"number\",\\n                \"description\": \"The percentage discount\"\\n            }\\n        },\\n        \"required\": [\\n            \"original_price\",\\n            \"discount_percentage\"\\n        ]\\n    }\\n}\\n'},\n",
              " {'from': 'human',\n",
              "  'value': \"Hi, I saw a dress that I liked in a store. It was originally priced at $200 but it's on a 20% discount. Can you help me calculate how much I will save?\"},\n",
              " {'from': 'gpt',\n",
              "  'value': '<functioncall> {\"name\": \"calculate_discount\", \"arguments\": \\'{\"original_price\": 200, \"discount_percentage\": 20}\\'} <|endoftext|>'},\n",
              " {'from': 'tool', 'value': '{\"discount_amount\": 40}'},\n",
              " {'from': 'gpt',\n",
              "  'value': 'Sure, with a 20% discount on a $200 dress, you will save $40. <|endoftext|>'}]"
            ]
          },
          "execution_count": 8,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset[10]['conversations']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WQunlnd6avP-"
      },
      "source": [
        "### Create a custom chat template\n",
        "\n",
        "Hugging Face supports chat templates that can be used to define the structure and format for converting conversations into a single tokenizable string, which is the input format expected by the language model. Check the [chat templates documentation](https://huggingface.co/docs/transformers/main/en/chat_templating) to know more about templates and how to create a custom new one.\n",
        "\n",
        "Since Gemma doesn't support system instructions, you will provide system input as user input. To read more about the format expected by Gemma, check out the [Gemma formatting doc](https://ai.google.dev/gemma/docs/formatting)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sDTk1-bcZz_2"
      },
      "outputs": [],
      "source": [
        "# Reference: https://github.com/unslothai/unsloth/blob/main/unsloth/chat_templates.py#L383\n",
        "\n",
        "chat_template = \\\n",
        "    \"{{ bos_token }}\"\\\n",
        "    \"{% if messages[0]['from'] == 'system' %}\"\\\n",
        "        \"{{'<start_of_turn>user\\n' + messages[0]['value'] | trim + ' ' + messages[1]['value'] | trim + '<end_of_turn>\\n'}}\"\\\n",
        "        \"{% set messages = messages[2:] %}\"\\\n",
        "    \"{% endif %}\"\\\n",
        "    \"{% for message in messages %}\"\\\n",
        "        \"{% if message['from'] == 'human' %}\"\\\n",
        "            \"{{'<start_of_turn>user\\n' + message['value'] | trim + '<end_of_turn>\\n'}}\"\\\n",
        "        \"{% elif message['from'] == 'gpt' %}\"\\\n",
        "            \"{{'<start_of_turn>model\\n' + message['value'] | trim + '<end_of_turn>\\n' }}\"\\\n",
        "        \"{% endif %}\"\\\n",
        "    \"{% endfor %}\"\\\n",
        "    \"{% if add_generation_prompt %}\"\\\n",
        "        \"{{ '<start_of_turn>model\\n' }}\"\\\n",
        "    \"{% endif %}\"\n",
        "\n",
        "tokenizer.chat_template = chat_template"
      ]
    },
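    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the template's behavior easier to follow, here is a minimal plain-Python sketch of the same logic. The function name `render_gemma_chat`, its default `bos_token='<bos>'`, and the toy conversation are hypothetical and for illustration only; the notebook itself relies on `tokenizer.apply_chat_template`.\n",
        "\n",
        "```python\n",
        "def render_gemma_chat(messages, bos_token='<bos>', add_generation_prompt=False):\n",
        "    # Hypothetical helper that mirrors the Jinja template in plain Python.\n",
        "    parts = [bos_token]\n",
        "    if messages and messages[0]['from'] == 'system':\n",
        "        # Gemma has no system role, so the system text is merged into the\n",
        "        # first user turn, separated by a single space.\n",
        "        merged = messages[0]['value'].strip() + ' ' + messages[1]['value'].strip()\n",
        "        parts.append('<start_of_turn>user\\n' + merged + '<end_of_turn>\\n')\n",
        "        messages = messages[2:]\n",
        "    for message in messages:\n",
        "        if message['from'] == 'human':\n",
        "            parts.append('<start_of_turn>user\\n' + message['value'].strip() + '<end_of_turn>\\n')\n",
        "        elif message['from'] == 'gpt':\n",
        "            parts.append('<start_of_turn>model\\n' + message['value'].strip() + '<end_of_turn>\\n')\n",
        "        # Any other role, such as 'tool', produces no output.\n",
        "    if add_generation_prompt:\n",
        "        parts.append('<start_of_turn>model\\n')\n",
        "    return ''.join(parts)\n",
        "\n",
        "\n",
        "# Toy conversation in the ShareGPT layout used by the dataset\n",
        "conversation = [\n",
        "    {'from': 'system', 'value': 'You can call get_time().\\n'},\n",
        "    {'from': 'human', 'value': 'What time is it?'},\n",
        "    {'from': 'gpt', 'value': '<functioncall> {\"name\": \"get_time\"}'},\n",
        "    {'from': 'tool', 'value': '{\"time\": \"12:00\"}'},\n",
        "    {'from': 'gpt', 'value': 'It is 12:00.'},\n",
        "]\n",
        "rendered = render_gemma_chat(conversation)\n",
        "print(rendered)\n",
        "```\n",
        "\n",
        "In this sketch the `tool` turn is dropped, which matches the template above: it only renders `human` and `gpt` turns, so two `model` turns can appear back to back."
      ]
    },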
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-N4h1qU-d3ow"
      },
      "source": [
        "### Define the formatting function\n",
        "\n",
        "The formatting function applies the template created above to each row in the dataset and converts it into a format suited for training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ja3EU-OKZz_4"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "ebe13895d65f4f1bbd78a562bd3dbe1a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Map:   0%|          | 0/16944 [00:00<?, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "def formatting_prompts_func(examples):\n",
        "    convos = examples[\"conversations\"]\n",
        "    texts = [tokenizer.apply_chat_template(convo, tokenize = False,\n",
        "                      add_generation_prompt = False) for convo in convos]\n",
        "    return { \"text\" : texts, }\n",
        "\n",
        "dataset = dataset.map(formatting_prompts_func, batched = True,)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vaqGj9J-eafu"
      },
      "source": [
        "### Clean up the dataset.\n",
        "\n",
        "Remove unnecessary tokens from the dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oA8_FWXgZz_5"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "                                                                                                                                                                                                                                                                                                                                                                 chat  \\\n",
            "0  USER: Hi, I have a list of numbers and I need to find the median. The numbers are 5, 2, 9, 1, 7, 4, 6, 3, 8.\\n\\n\\nASSISTANT: <functioncall> {\"name\": \"calculate_median\", \"arguments\": '{\"numbers\": [5, 2, 9, 1, 7, 4, 6, 3, 8]}'} <|endoftext|>\\n\\n\\nFUNCTION RESPONSE: {\"median\": 5}\\n\\n\\nASSISTANT: The median of your list of numbers is 5. <|endoftext|>\\n\\n\\n   \n",
            "\n",
            "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     system  \\\n",
            "0  SYSTEM: You are a helpful assistant with access to the following functions. Use them if required -\\n{\\n    \"name\": \"calculate_median\",\\n    \"description\": \"Calculate the median of a list of numbers\",\\n    \"parameters\": {\\n        \"type\": \"object\",\\n        \"properties\": {\\n            \"numbers\": {\\n                \"type\": \"array\",\\n                \"items\": {\\n                    \"type\": \"number\"\\n                },\\n                \"description\": \"The list of numbers\"\\n            }\\n        },\\n        \"required\": [\\n            \"numbers\"\\n        ]\\n    }\\n}\\n   \n",
            "\n",
            "  __hfsplit__  \\\n",
            "0       train   \n",
            "\n",
            "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         conversations  \\\n",
            "0  [{'from': 'system', 'value': 'You are a helpful assistant with access to the following functions. Use them if required -\n",
            "{\n",
            "    \"name\": \"calculate_median\",\n",
            "    \"description\": \"Calculate the median of a list of numbers\",\n",
            "    \"parameters\": {\n",
            "        \"type\": \"object\",\n",
            "        \"properties\": {\n",
            "            \"numbers\": {\n",
            "                \"type\": \"array\",\n",
            "                \"items\": {\n",
            "                    \"type\": \"number\"\n",
            "                },\n",
            "                \"description\": \"The list of numbers\"\n",
            "            }\n",
            "        },\n",
            "        \"required\": [\n",
            "            \"numbers\"\n",
            "        ]\n",
            "    }\n",
            "}\n",
            "'}, {'from': 'human', 'value': 'Hi, I have a list of numbers and I need to find the median. The numbers are 5, 2, 9, 1, 7, 4, 6, 3, 8.'}, {'from': 'gpt', 'value': '<functioncall> {\"name\": \"calculate_median\", \"arguments\": '{\"numbers\": [5, 2, 9, 1, 7, 4, 6, 3, 8]}'} <|endoftext|>'}, {'from': 'tool', 'value': '{\"median\": 5}'}, {'from': 'gpt', 'value': 'The median of your list of numbers is 5. <|endoftext|>'}]   \n",
            "\n",
            "                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     text  \n",
            "0  <bos><start_of_turn>user\\nYou are a helpful assistant with access to the following functions. Use them if required -\\n{\\n    \"name\": \"calculate_median\",\\n    \"description\": \"Calculate the median of a list of numbers\",\\n    \"parameters\": {\\n        \"type\": \"object\",\\n        \"properties\": {\\n            \"numbers\": {\\n                \"type\": \"array\",\\n                \"items\": {\\n                    \"type\": \"number\"\\n                },\\n                \"description\": \"The list of numbers\"\\n            }\\n        },\\n        \"required\": [\\n            \"numbers\"\\n        ]\\n    }\\n} Hi, I have a list of numbers and I need to find the median. The numbers are 5, 2, 9, 1, 7, 4, 6, 3, 8.<end_of_turn>\\n<start_of_turn>model\\n<functioncall> {\"name\": \"calculate_median\", \"arguments\": '{\"numbers\": [5, 2, 9, 1, 7, 4, 6, 3, 8]}'} <end_of_turn>\\n<start_of_turn>model\\nThe median of your list of numbers is 5. <end_of_turn>\\n  \n"
          ]
        }
      ],
      "source": [
        "import pandas as pd\n",
        "\n",
        "df_train = pd.DataFrame(dataset)\n",
        "df_train[\"text\"] = df_train[\"text\"].apply(\n",
        "    lambda x: x.replace(\"<|endoftext|>\", \"\"))\n",
        "\n",
        "pd.set_option('display.max_colwidth', None)\n",
        "print(df_train.head(1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dJ4IGfcsey7l"
      },
      "source": [
        "Convert the dataset back to Hugging Face's `Dataset` format."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vSkEjrEOfBjD"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "Dataset({\n",
              "    features: ['text'],\n",
              "    num_rows: 16944\n",
              "})"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "dataset = Dataset.from_pandas(df_train[['text']])\n",
        "\n",
        "dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qDMvzby5fcn_"
      },
      "source": [
        "### LoRA configuration\n",
        "\n",
        "LoRA(Low-Rank Adaptation) introduces small, trainable matrices into the model's architecture, specifically targeting the attention layers of Transformer models. Instead of updating the full weight matrices, LoRA adds rank-decomposed matrices, making adaptation more efficient.\n",
        "\n",
        "Here, you set the following parameters:\n",
        "- `r` to 16, which controls the rank of the adaptation matrices.\n",
        "- `lora_alpha` to 16 for scaling.\n",
        "- `lora_dropout` to 0 since it is optimized."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TeMIxx8GZz_5"
      },
      "outputs": [],
      "source": [
        "from peft import LoraConfig, PeftModel\n",
        "\n",
        "# Load LoRA configuration\n",
        "peft_config = LoraConfig(\n",
        "    lora_alpha=16,       # Alpha parameter for LoRA scaling\n",
        "    lora_dropout=0,    # Dropout probability for LoRA layers\n",
        "    r=16,                # LoRA attention dimension\n",
        "    bias=\"none\",\n",
        "    task_type=\"CAUSAL_LM\",\n",
        "    target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
        "                      \"gate_proj\", \"up_proj\", \"down_proj\",]\n",
        ")"
      ]
    },
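    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what these settings mean, the standard LoRA formulation (stated here in general terms, not taken from this notebook) leaves the pre-trained weight $W_0$ of each targeted linear layer frozen and adds a scaled low-rank update:\n",
        "\n",
        "$$h = W_0 x + \\frac{\\alpha}{r} B A x, \\qquad W_0 \\in \\mathbb{R}^{d \\times k},\\; B \\in \\mathbb{R}^{d \\times r},\\; A \\in \\mathbb{R}^{r \\times k}$$\n",
        "\n",
        "Only $A$ and $B$ are trained. The symbols $d$ and $k$ are generic layer dimensions assumed for illustration. With `lora_alpha=16` and `r=16`, the scaling factor is $\\alpha / r = 1$, and each targeted matrix adds $r(d + k)$ trainable parameters instead of the $dk$ parameters of a full update."
      ]
    },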
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jmOGm36Qg8O4"
      },
      "source": [
        "The **Fully Sharded Data Parallel (FSDP)** configuration is set up in `fsdp_config`, enabling [**full model sharding**](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy) and [**gradient checkpointing**](https://huggingface.co/docs/transformers/v4.19.4/en/performance#gradient-checkpointing) for memory efficiency on TPUs, and specifying that gradient checkpointing should be enabled with `xla_fsdp_grad_ckpt`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9ixCCNgcg85g"
      },
      "outputs": [],
      "source": [
        "# Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.\n",
        "fsdp_config = {\n",
        "    \"fsdp_transformer_layer_cls_to_wrap\": [\n",
        "        \"Gemma2DecoderLayer\"\n",
        "    ],\n",
        "    \"xla\": True,\n",
        "    \"xla_fsdp_v2\": True,\n",
        "    \"xla_fsdp_grad_ckpt\": True\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tmISNVMYf8vZ"
      },
      "source": [
        "### Set training configuration\n",
        "\n",
        "Set up the training arguments that define how the model will be trained.\n",
        "\n",
        "Here, you'll define the following parameters:\n",
        "\n",
        "- For training:\n",
        "  - `output directory`\n",
        "  - `max steps`\n",
        "  - `batch sizes`\n",
        "\n",
        "- To optimize the training process:\n",
        "  - `learning rate`\n",
        "  - `optimizer`\n",
        "  - `learning rate scheduler`\n",
        "\n",
        "**Note:** `max_steps` is set as 100 steps to speed things up, but you can set `num_train_epochs=1` for a full run."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pt69suWMZz_6"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:root:torch_xla.core.xla_model.xrt_world_size() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.world_size instead.\n",
            "WARNING:root:torch_xla.core.xla_model.xla_model.get_ordinal() will be removed in release 2.7. is deprecated. Use torch_xla.runtime.global_ordinal instead.\n"
          ]
        }
      ],
      "source": [
        "from trl import SFTTrainer, SFTConfig\n",
        "\n",
        "# Set training parameters\n",
        "training_arguments = SFTConfig(\n",
        "    # ---Output settings--\n",
        "    # Output directory where model predictions and checkpoints will be stored\n",
        "    output_dir=\"./results\",\n",
        "    overwrite_output_dir=True,\n",
        "    save_strategy=\"no\",\n",
        "    # ---Training settings---\n",
        "    # Number of training epochs\n",
        "    #num_train_epochs=1,\n",
        "    # Number of training steps (overrides num_train_epochs)\n",
        "    max_steps=100,\n",
        "    # This is the global train batch size for SPMD\n",
        "    # Batch size per GPU core for training\n",
        "    per_device_train_batch_size=32,\n",
        "    # Number of update steps to accumulate the gradients for\n",
        "    gradient_accumulation_steps=1,\n",
        "    # Optimizer to use\n",
        "    optim=\"adafactor\",\n",
        "    # Required for SPMD\n",
        "    dataloader_drop_last=True,\n",
        "    fsdp=\"full_shard\",\n",
        "    fsdp_config=fsdp_config,\n",
        "    # Initial learning rate (adafactor optimizer)\n",
        "    learning_rate=0.0002,\n",
        "    # Enable bfloat16 precision\n",
        "    bf16=True,\n",
        "    # Maximum gradient normal (gradient clipping)\n",
        "    max_grad_norm=0.3,\n",
        "    # Ratio of steps for a linear warmup (from 0 to learning rate)\n",
        "    warmup_ratio=0.03,\n",
        "    # Learning rate schedule (constant a bit better than cosine)\n",
        "    lr_scheduler_type=\"linear\",\n",
        "    # Maximum sequence length to use\n",
        "    max_seq_length=1024,\n",
        "    dataset_text_field=\"text\",\n",
        "    dataset_kwargs={\n",
        "        \"add_special_tokens\": False,\n",
        "        \"append_concat_token\": False,\n",
        "    },\n",
        "    # Pack multiple short examples in the same input sequence\n",
        "    # to increase efficiency\n",
        "    packing=True,\n",
        "    # ---Logging---\n",
        "    # Log every X update step\n",
        "    logging_steps=1,\n",
        "    report_to=\"none\",\n",
        "    seed=42\n",
        ")"
      ]
    },
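    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see how `warmup_ratio=0.03`, `max_steps=100`, and `lr_scheduler_type=\"linear\"` interact, the following is a minimal sketch of a linear warmup followed by a linear decay. The helper `linear_schedule_lr` is hypothetical and only approximates the scheduler used by the trainer; it is not part of TRL or Transformers.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def linear_schedule_lr(step, base_lr=2e-4, max_steps=100, warmup_ratio=0.03):\n",
        "    # Hypothetical helper: linear warmup from 0, then linear decay to 0.\n",
        "    warmup_steps = math.ceil(max_steps * warmup_ratio)\n",
        "    if step < warmup_steps:\n",
        "        return base_lr * step / max(1, warmup_steps)\n",
        "    remaining = max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))\n",
        "    return base_lr * remaining\n",
        "\n",
        "\n",
        "# Print the learning rate at a few representative steps\n",
        "for step in (0, 1, 3, 50, 100):\n",
        "    print(step, linear_schedule_lr(step))\n",
        "```\n",
        "\n",
        "Under these assumptions, the warmup covers the first 3 of the 100 steps, after which the learning rate decays linearly toward 0."
      ]
    },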
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YATzmXmxgg5q"
      },
      "source": [
        "### Train the model\n",
        "\n",
        "[Huggingface's TRL](https://huggingface.co/docs/trl/index) offers a user-friendly API for building SFT models and training them on your dataset with just a few lines of code. Here you will use Huggingface TRL's `SFTTrainer` class to train the model. This class inherits from the `Trainer` class available in the Transformers library but is specifically optimized for supervised fine-tuning (instruction tuning). Read more about SFFTrainer from the [official TRL SFT docs](https://huggingface.co/docs/trl/sft_trainer)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ftyGnb5KZz_6"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "2001cbfa83d641c3b2f75d4163bbd235",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Generating train split: 0 examples [00:00, ? examples/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:403: UserWarning: You passed a processing_class with `padding_side` not equal to `right` to the SFTTrainer. This might lead to some unexpected behaviour due to overflow issues when training a model in half-precision. You might consider adding `processing_class.padding_side = 'right'` to your code.\n",
            "  warnings.warn(\n",
            "max_steps is given, it will override any value given in num_train_epochs\n",
            "/usr/local/lib/python3.10/dist-packages/trl/trainer/sft_trainer.py:428: UserWarning: You passed `packing=True` to the SFTTrainer/SFTConfig, and you are training your model with `max_steps` strategy. The dataset will be iterated until the `max_steps` are reached.\n",
            "  warnings.warn(\n"
          ]
        }
      ],
      "source": [
        "# Set supervised fine-tuning parameters\n",
        "trainer = SFTTrainer(\n",
        "    model=model,\n",
        "    train_dataset=dataset,\n",
        "    peft_config=peft_config,\n",
        "    args=training_arguments\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RCIBn9TCgxmZ"
      },
      "source": [
        "Now, let's start the fine-tuning process by calling `trainer.train()`, which uses `SFTTrainer` to handle the training loop, including data loading, forward and backward passes, and optimizer steps, all configured according to the settings you've provided."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XPz4ia15Zz_7"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1810: UserWarning: For backward hooks to be called, module output should be a Tensor or a tuple of Tensors but received <class 'transformers.modeling_outputs.CausalLMOutputWithPast'>\n",
            "  warnings.warn(\"For backward hooks to be called,\"\n",
            "/usr/local/lib/python3.10/dist-packages/torch_xla/utils/checkpoint.py:183: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.\n",
            "  torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \\\n",
            "/usr/local/lib/python3.10/dist-packages/torch_xla/utils/checkpoint.py:184: FutureWarning: `torch.cpu.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cpu', args...)` instead.\n",
            "  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):\n"
          ]
        },
        {
          "data": {
            "text/html": [
              "\n",
              "    <div>\n",
              "      \n",
              "      <progress value='100' max='100' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
              "      [100/100 13:22, Epoch 0/1]\n",
              "    </div>\n",
              "    <table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              " <tr style=\"text-align: left;\">\n",
              "      <th>Step</th>\n",
              "      <th>Training Loss</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <td>1</td>\n",
              "      <td>2.093800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>2</td>\n",
              "      <td>2.125000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>3</td>\n",
              "      <td>2.078100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>4</td>\n",
              "      <td>1.765600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>5</td>\n",
              "      <td>1.546900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>6</td>\n",
              "      <td>1.218800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>7</td>\n",
              "      <td>1.164100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>8</td>\n",
              "      <td>1.039100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>9</td>\n",
              "      <td>1.125000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>10</td>\n",
              "      <td>0.953100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>11</td>\n",
              "      <td>0.882800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>12</td>\n",
              "      <td>0.875000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>13</td>\n",
              "      <td>0.882800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>14</td>\n",
              "      <td>0.800800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>15</td>\n",
              "      <td>0.789100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>16</td>\n",
              "      <td>0.793000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>17</td>\n",
              "      <td>0.808600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>18</td>\n",
              "      <td>0.832000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>19</td>\n",
              "      <td>0.761700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>20</td>\n",
              "      <td>0.730500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>21</td>\n",
              "      <td>0.796900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>22</td>\n",
              "      <td>0.750000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>23</td>\n",
              "      <td>0.746100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>24</td>\n",
              "      <td>0.671900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>25</td>\n",
              "      <td>0.699200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>26</td>\n",
              "      <td>0.617200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>27</td>\n",
              "      <td>0.668000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>28</td>\n",
              "      <td>0.761700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>29</td>\n",
              "      <td>0.597700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>30</td>\n",
              "      <td>0.625000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>31</td>\n",
              "      <td>0.742200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>32</td>\n",
              "      <td>0.660200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>33</td>\n",
              "      <td>0.578100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>34</td>\n",
              "      <td>0.546900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>35</td>\n",
              "      <td>0.546900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>36</td>\n",
              "      <td>0.625000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>37</td>\n",
              "      <td>0.558600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>38</td>\n",
              "      <td>0.660200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>39</td>\n",
              "      <td>0.609400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>40</td>\n",
              "      <td>0.582000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>41</td>\n",
              "      <td>0.605500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>42</td>\n",
              "      <td>0.535200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>43</td>\n",
              "      <td>0.621100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>44</td>\n",
              "      <td>0.617200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>45</td>\n",
              "      <td>0.585900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>46</td>\n",
              "      <td>0.640600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>47</td>\n",
              "      <td>0.566400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>48</td>\n",
              "      <td>0.605500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>49</td>\n",
              "      <td>0.511700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>50</td>\n",
              "      <td>0.535200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>51</td>\n",
              "      <td>0.566400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>52</td>\n",
              "      <td>0.566400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>53</td>\n",
              "      <td>0.593800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>54</td>\n",
              "      <td>0.515600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>55</td>\n",
              "      <td>0.582000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>56</td>\n",
              "      <td>0.613300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>57</td>\n",
              "      <td>0.648400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>58</td>\n",
              "      <td>0.617200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>59</td>\n",
              "      <td>0.609400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>60</td>\n",
              "      <td>0.617200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>61</td>\n",
              "      <td>0.589800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>62</td>\n",
              "      <td>0.597700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>63</td>\n",
              "      <td>0.601600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>64</td>\n",
              "      <td>0.582000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>65</td>\n",
              "      <td>0.644500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>66</td>\n",
              "      <td>0.539100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>67</td>\n",
              "      <td>0.640600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>68</td>\n",
              "      <td>0.507800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>69</td>\n",
              "      <td>0.609400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>70</td>\n",
              "      <td>0.585900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>71</td>\n",
              "      <td>0.613300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>72</td>\n",
              "      <td>0.617200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>73</td>\n",
              "      <td>0.636700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>74</td>\n",
              "      <td>0.515600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>75</td>\n",
              "      <td>0.554700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>76</td>\n",
              "      <td>0.562500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>77</td>\n",
              "      <td>0.574200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>78</td>\n",
              "      <td>0.562500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>79</td>\n",
              "      <td>0.617200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>80</td>\n",
              "      <td>0.543000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>81</td>\n",
              "      <td>0.664100</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>82</td>\n",
              "      <td>0.562500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>83</td>\n",
              "      <td>0.488300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>84</td>\n",
              "      <td>0.488300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>85</td>\n",
              "      <td>0.660200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>86</td>\n",
              "      <td>0.609400</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>87</td>\n",
              "      <td>0.511700</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>88</td>\n",
              "      <td>0.507800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>89</td>\n",
              "      <td>0.535200</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>90</td>\n",
              "      <td>0.625000</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>91</td>\n",
              "      <td>0.546900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>92</td>\n",
              "      <td>0.550800</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>93</td>\n",
              "      <td>0.546900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>94</td>\n",
              "      <td>0.605500</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>95</td>\n",
              "      <td>0.515600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>96</td>\n",
              "      <td>0.558600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>97</td>\n",
              "      <td>0.570300</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>98</td>\n",
              "      <td>0.515600</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>99</td>\n",
              "      <td>0.546900</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <td>100</td>\n",
              "      <td>0.484400</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table><p>"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "text/plain": [
              "TrainOutput(global_step=100, training_loss=0.708515625, metrics={'train_runtime': 849.5865, 'train_samples_per_second': 3.767, 'train_steps_per_second': 0.118, 'total_flos': 5.18083433201664e+16, 'train_loss': 0.708515625, 'epoch': 0.3861003861003861})"
            ]
          },
          "execution_count": 17,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "trainer.train()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3C4X9QUZhGOE"
      },
      "source": [
        "After training is complete, you save the fine-tuned model by moving it to the CPU with `trainer.model.to('cpu')` to ensure compatibility and then calling `save_pretrained(new_model)` to save the model weights and configuration files to the directory specified by `new_model` (**gemma-func-ft**). This allows you to reload and use the fine-tuned model later for inference or further training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-X2vfpRYZz_7"
      },
      "outputs": [],
      "source": [
        "# Remove the model weights directory if it exists\n",
        "!rm -rf gemma-func-ft\n",
        "\n",
        "# Save the LoRA adapter\n",
        "trainer.model.to('cpu').save_pretrained(new_model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VwjYKWW9hT0D"
      },
      "source": [
        "## Prompt using the newly fine-tuned model\n",
        "\n",
        "\n",
        "Now that you've finally fine-tuned your custom Gemma model, let's reload the LoRA adapter weights to finally prompt using it and also verify if it's really working as intended."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KiwO7V7LhUj8"
      },
      "source": [
        "To do this, use the following steps to correctly reload the adapter weights:\n",
        "\n",
        "- Use `AutoModelForCausalLM.from_pretrained` to first load the **base Gemma model**, while setting `low_cpu_mem_usage=True` to optimize memory consumption (since you're using a TPU) and `torch_dtype=torch.bfloat16` for consistency with the fine-tuned model.\n",
        "\n",
        "- Load the **fine-tuned LoRA adapter** that you've previously saved into the base model using `PeftModel.from_pretrained`, where `new_model` is the directory containing your fine-tuned weights.\n",
        "\n",
        "- The `model.merge_and_unload` function **merges** the **LoRA adapter weights** with the **base model weights** and unloads the adapter, resulting in a standalone model ready for inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c9rLNU-dZz_8"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "b1411b88d40b445382886110923d5044",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Reload the fine-tuned Gemma model\n",
        "base_model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    low_cpu_mem_usage=True,\n",
        "    return_dict=True,\n",
        "    torch_dtype=torch.bfloat16\n",
        ")\n",
        "model = PeftModel.from_pretrained(base_model, new_model)\n",
        "model = model.merge_and_unload()"
      ]
    },
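    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build intuition for the merge step above, here is a minimal sketch of the arithmetic for a single linear layer, written in plain Python. It assumes the standard LoRA update `W + (lora_alpha / r) * B @ A` and ignores dropout. The matrices, the values of `lora_alpha` and `r`, and the helper names `matmul` and `add_scaled` are made up for illustration and are not part of PEFT."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Toy illustration of merging a LoRA adapter into one linear layer.\n",
        "# All shapes and values are invented; real layers are far larger.\n",
        "\n",
        "def matmul(a, b):\n",
        "    # Multiply two matrices stored as lists of rows.\n",
        "    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]\n",
        "\n",
        "def add_scaled(a, b, scale=1.0):\n",
        "    # Element-wise a + scale * b.\n",
        "    return [[x + scale * y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]\n",
        "\n",
        "# Frozen base weight W (2x3) and low-rank factors B (2x1), A (1x3), so r = 1.\n",
        "W = [[1.0, 0.0, 2.0],\n",
        "     [0.0, 1.0, -1.0]]\n",
        "B = [[0.5],\n",
        "     [-1.0]]\n",
        "A = [[2.0, 0.0, 4.0]]\n",
        "lora_alpha, r = 2, 1\n",
        "scaling = lora_alpha / r\n",
        "\n",
        "# Merging folds the scaled low-rank update into the base weight once.\n",
        "W_merged = add_scaled(W, matmul(B, A), scale=scaling)\n",
        "\n",
        "# An input column vector x (3x1).\n",
        "x = [[1.0], [2.0], [3.0]]\n",
        "\n",
        "# Adapter kept separate: base path plus scaled low-rank path.\n",
        "y_adapter = add_scaled(matmul(W, x), matmul(B, matmul(A, x)), scale=scaling)\n",
        "# Adapter merged: a single matmul with the merged weight.\n",
        "y_merged = matmul(W_merged, x)\n",
        "\n",
        "print('adapter kept separate:', y_adapter)\n",
        "print('adapter merged:       ', y_merged)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Both paths produce the same output vector, which is why the merged model can be used for inference without keeping the adapter around as a separate set of weights."
      ]
    },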
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qvbZVUB1hduS"
      },
      "source": [
        "Reload the tokenizer to ensure it matches the model configuration, adjusting the padding side as before."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RipNS_4JZz_8"
      },
      "outputs": [],
      "source": [
        "# Reload tokenizer\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "tokenizer.padding_side = \"right\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5K7wfhcGhmXB"
      },
      "source": [
        "Now, test the fine-tuned model with a sample prompt by first using the tokenizer to generate the input ids, and then relying on the reloaded fine-tuned model to generate a response using `model.generate()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e4In0uDCZz_-"
      },
      "outputs": [],
      "source": [
        "input_text = \"\"\"\\\n",
        "<start_of_turn>user\n",
        "You are a helpful assistant with access to the following functions. Use them if required -\n",
        "{\n",
        "    \"name\": \"calculate_median\",\n",
        "    \"description\": \"Calculate the median of a list of numbers\",\n",
        "    \"parameters\": {\n",
        "        \"type\": \"object\",\n",
        "        \"properties\": {\n",
        "             \"numbers\": {\n",
        "                 \"type\": \"array\",\n",
        "                 \"items\": {\n",
        "                     \"type\": \"number\"\n",
        "                 },\n",
        "                 \"description\": \"The list of numbers\"\n",
        "             }\n",
        "        }\n",
        "        \"required\": [\n",
        "            \"numbers\"\n",
        "        ]\n",
        "    }\n",
        "}\n",
        "To use these functions respond with:\n",
        "<functioncall> {\"name\": \"function_name\", \"arguments\": {\"arg_1\": \"value_1\", \"arg_1\": \"value_1\", ...}} </functioncall>\n",
        "\n",
        "Then finally respond with:\n",
        "Answer:\n",
        "\n",
        "<end_of_turn>\n",
        "<start_of_turn>user\n",
        "USER: Hi, I have a list of numbers and I need to find the median. The numbers are [5, 2, 9, 1, 7, 4, 6, 3, 8]\n",
        "<end_of_turn>\n",
        "<start_of_turn>model\n",
        "<functioncall>\n",
        "\"\"\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f5SmpC3JZz__"
      },
      "outputs": [],
      "source": [
        "input_ids = tokenizer(input_text, return_tensors=\"pt\").to(\"cpu\")\n",
        "outputs = model.generate(**input_ids, max_new_tokens = 512)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "t4YF9rRZhrlR"
      },
      "source": [
        "Finally, you decode the output tokens back into human-readable text with `tokenizer.decode` and print the result, allowing you to see how the fine-tuned model responds to the prompt."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-ndkbdPJZz__"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "<bos><start_of_turn>user\n",
            "You are a helpful assistant with access to the following functions. Use them if required -\n",
            "{\n",
            "    \"name\": \"calculate_median\",\n",
            "    \"description\": \"Calculate the median of a list of numbers\",\n",
            "    \"parameters\": {\n",
            "        \"type\": \"object\",\n",
            "        \"properties\": { \n",
            "             \"numbers\": {\n",
            "                 \"type\": \"array\",\n",
            "                 \"items\": {\n",
            "                     \"type\": \"number\"              \n",
            "                 },\n",
            "                 \"description\": \"The list of numbers\"\n",
            "             }      \n",
            "        }       \n",
            "        \"required\": [\n",
            "            \"numbers\"       \n",
            "        ]    \n",
            "    }\n",
            "}\n",
            "To use these functions respond with:\n",
            "<functioncall> {\"name\": \"function_name\", \"arguments\": {\"arg_1\": \"value_1\", \"arg_1\": \"value_1\", ...}} </functioncall>\n",
            "\n",
            "Then finally respond with:\n",
            "Answer:\n",
            "\n",
            "<end_of_turn>\n",
            "<start_of_turn>user\n",
            "USER: Hi, I have a list of numbers and I need to find the median. The numbers are [5, 2, 9, 1, 7, 4, 6, 3, 8]\n",
            "<end_of_turn>\n",
            "<start_of_turn>model\n",
            "<functioncall>\n",
            "{\"name\": \"calculate_median\", \"arguments\": {\"numbers\": [5, 2, 9, 1, 7, 4, 6, 3, 8]}}\n",
            "</functioncall> \n",
            "\n",
            "Answer: \n",
            "```\n",
            "The median of the list [5, 2, 9, 1, 7, 4, 6, 3, 8] is 4. \n",
            "``` \n",
            "<end_of_turn>\n"
          ]
        }
      ],
      "source": [
        "print(tokenizer.decode(outputs[0]))"
      ]
    },
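    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In an application, the text between the `<functioncall>` tags is the part you act on: parse it as JSON and run the named function yourself. The following is a minimal sketch of that step using only the Python standard library. The `TOOLS` registry and the `extract_call` helper are hypothetical names introduced for illustration, and the `generated` string is copied from the output above with the prompt removed."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import json\n",
        "import statistics\n",
        "\n",
        "OPEN_TAG = \"<functioncall>\"\n",
        "CLOSE_TAG = \"</functioncall>\"\n",
        "\n",
        "# Model completion copied from the output above (prompt omitted).\n",
        "generated = \"\"\"<functioncall>\n",
        "{\"name\": \"calculate_median\", \"arguments\": {\"numbers\": [5, 2, 9, 1, 7, 4, 6, 3, 8]}}\n",
        "</functioncall>\"\"\"\n",
        "\n",
        "# Hypothetical registry of local functions the model is allowed to call.\n",
        "TOOLS = {\"calculate_median\": lambda numbers: statistics.median(numbers)}\n",
        "\n",
        "def extract_call(text):\n",
        "    # Use the last tag pair, since a full prompt also contains a template example.\n",
        "    start = text.rfind(OPEN_TAG)\n",
        "    end = text.find(CLOSE_TAG, start)\n",
        "    if start == -1 or end == -1:\n",
        "        return None\n",
        "    return json.loads(text[start + len(OPEN_TAG):end])\n",
        "\n",
        "# Parse the call, look up the function by name, and run it with the given arguments.\n",
        "call = extract_call(generated)\n",
        "result = TOOLS[call[\"name\"]](**call[\"arguments\"])\n",
        "print(call[\"name\"], \"->\", result)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For this input the sketch prints `calculate_median -> 5`. The free-text answer in the recorded output above states 4, so it is safer to compute the result from the parsed call than to rely on the model's own arithmetic."
      ]
    },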
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-U9FWCfkh2QK"
      },
      "source": [
        "Congratulations! You've successfully fine-tuned Gemma for Function Calling using Torch XLA and PEFT with LoRA on TPUs. With that, you've covered the entire process, from setting up the environment to training and testing the model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TqXpz36GhyYS"
      },
      "source": [
        "## What's next?\n",
        "Your next steps could include the following:\n",
        "\n",
        "- **Experiment with Different Datasets**: Try fine-tuning on other function calling datasets in [Hugging Face](https://huggingface.co/docs/datasets/en/index) or your own data.\n",
        "\n",
        "- **Tune Hyperparameters**: Adjust training parameters (e.g., learning rate, batch size, epochs, LoRA settings) to optimize performance and\n",
        "improve training efficiency.\n",
        "\n",
        "- **Try different templates**: Try different chat templates and try to improve the performance.\n",
        "\n",
        "By exploring these activities, you'll deepen your understanding and further enhance your fine-tuned Gemma model. Happy experimenting!"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "name": "[Gemma_2]Finetune_with_Function_Calling.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
